package com.jstarcraft.ai.jsat.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.distributions.kernels.KernelTrick;
import com.jstarcraft.ai.jsat.linear.DenseMatrix;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Matrix;
import com.jstarcraft.ai.jsat.linear.SubMatrix;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.parameters.Parameter.ParameterHolder;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.ListUtils;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.IntArrayList;

/**
 * Provides an implementation of the Kernel Recursive Least Squares algorithm.
 * This algorithm updates the model one per data point, and induces sparsity by
 * projecting data points down onto a set of basis vectors learned from the data
 * stream. <br>
 * <br>
 * See: Engel, Y., Mannor, S.,&amp;Meir, R. (2004). <i>The Kernel Recursive
 * Least-Squares Algorithm</i>. IEEE Transactions on Signal Processing, 52(8),
 * 2275–2285. doi:10.1109/TSP.2004.830985
 * 
 * @author Edward Raff
 */
public class KernelRLS implements UpdateableRegressor, Parameterized {

    private static final long serialVersionUID = -7292074388953854317L;
    @ParameterHolder
    private KernelTrick k;
    private double errorTolerance;

    private List<Vec> vecs;
    private DoubleArrayList kernelAccel;
    private Matrix K;
    private Matrix InvK;
    private Matrix P;

    private Matrix KExpanded;
    private Matrix InvKExpanded;
    private Matrix PExpanded;
    private double[] alphaExpanded;

    /**
     * Creates a new Kernel RLS learner
     * 
     * @param k              the kernel trick to use
     * @param errorTolerance the tolerance for errors in the projection
     */
    public KernelRLS(KernelTrick k, double errorTolerance) {
        this.k = k;
        setErrorTolerance(errorTolerance);
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to copy
     */
    protected KernelRLS(KernelRLS toCopy) {
        this.k = toCopy.k.clone();
        this.errorTolerance = toCopy.errorTolerance;
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>(toCopy.vecs.size());
            for (Vec vec : toCopy.vecs)
                this.vecs.add(vec.clone());
        }

        if (toCopy.KExpanded != null) {
            this.KExpanded = toCopy.KExpanded.clone();
            this.K = new SubMatrix(KExpanded, 0, 0, vecs.size(), vecs.size());
        }
        if (toCopy.InvKExpanded != null) {
            this.InvKExpanded = toCopy.InvKExpanded.clone();
            this.InvK = new SubMatrix(InvKExpanded, 0, 0, vecs.size(), vecs.size());
        }
        if (toCopy.PExpanded != null) {
            this.PExpanded = toCopy.PExpanded.clone();
            this.P = new SubMatrix(PExpanded, 0, 0, vecs.size(), vecs.size());
        }
        if (toCopy.alphaExpanded != null)
            this.alphaExpanded = Arrays.copyOf(toCopy.alphaExpanded, toCopy.alphaExpanded.length);
    }

    /**
     * Sets the tolerance for errors in approximating a data point by projecting it
     * onto the set of basis vectors. In general: as the tolerance increases the
     * sparsity of the model increases but the accuracy may go down. <br>
     * Values in the range 10<sup>x</sup> &forall; x &isin; {-1, -2, -3, -4} often
     * work well for this algorithm.
     * 
     * @param v the approximation tolerance
     */
    public void setErrorTolerance(double v) {
        if (Double.isNaN(v) || Double.isInfinite(v) || v <= 0)
            throw new IllegalArgumentException("The error tolerance must be a positive constant, not " + v);
        this.errorTolerance = v;
    }

    /**
     * Returns the projection approximation tolerance
     * 
     * @return the projection approximation tolerance
     */
    public double getErrorTolerance() {
        return errorTolerance;
    }

    /**
     * Returns the number of basis vectors that make up the model
     * 
     * @return the number of basis vectors that make up the model
     */
    public int getModelSize() {
        if (vecs == null)
            return 0;
        return vecs.size();
    }

    /**
     * Finalizes the model. During online training, the the gram matrix and its
     * inverse must be stored to perform updates, at the cost of O(n<sup>2</sup>)
     * memory. One training is completed, these matrices are no longer needed - and
     * can be removed to reclaim memory by finalizing the model. Once finalized, the
     * model can no longer be updated - unless reset (destroying the model) by
     * calling
     * {@link #setUp(com.jstarcraft.ai.jsat.classifiers.CategoricalData[], int) }
     */
    public void finalizeModel() {
        alphaExpanded = Arrays.copyOf(alphaExpanded, vecs.size());// dont need extra
        K = KExpanded = InvK = InvKExpanded = P = PExpanded = null;
    }

    @Override
    public double regress(DataPoint data) {
        final Vec y = data.getNumericalValues();

        return k.evalSum(vecs, kernelAccel, alphaExpanded, y, 0, vecs.size());
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        setUp(dataSet.getCategories(), dataSet.getNumNumericalVars());
        IntArrayList randOrder = new IntArrayList(dataSet.size());
        ListUtils.addRange(randOrder, 0, dataSet.size(), 1);
        for (int i : randOrder)
            update(dataSet.getDataPoint(i), dataSet.getWeight(i), dataSet.getTargetValue(i));
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public KernelRLS clone() {
        return new KernelRLS(this);
    }

    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        vecs = new ArrayList<Vec>();
        if (k.supportsAcceleration())
            kernelAccel = new DoubleArrayList();
        else
            kernelAccel = null;

        K = null;
        InvK = null;
        P = null;

        KExpanded = new DenseMatrix(100, 100);
        InvKExpanded = new DenseMatrix(100, 100);
        PExpanded = new DenseMatrix(100, 100);
        alphaExpanded = new double[100];
    }

    @Override
    public void update(DataPoint dataPoint, double weight, final double y_t) {
        /*
         * TODO a lot of temporary allocations are done in this code, but potentially
         * change size - investigate storing them as well.
         */
        Vec x_t = dataPoint.getNumericalValues();

        final DoubleList qi = k.getQueryInfo(x_t);
        final double k_tt = k.eval(0, 0, Arrays.asList(x_t), qi);

        if (K == null)// first point to be added
        {
            K = new SubMatrix(KExpanded, 0, 0, 1, 1);
            K.set(0, 0, k_tt);
            InvK = new SubMatrix(InvKExpanded, 0, 0, 1, 1);
            InvK.set(0, 0, 1 / k_tt);
            P = new SubMatrix(PExpanded, 0, 0, 1, 1);
            P.set(0, 0, 1);
            alphaExpanded[0] = y_t / k_tt;
            vecs.add(x_t);
            if (kernelAccel != null)
                kernelAccel.addAll(qi);
            return;
        }

        // Normal case
        DenseVector kxt = new DenseVector(K.rows());

        for (int i = 0; i < kxt.length(); i++)
            kxt.set(i, k.eval(i, x_t, qi, vecs, kernelAccel));

        // ALD test
        final Vec alphas_t = InvK.multiply(kxt);
        final double delta_t = k_tt - alphas_t.dot(kxt);
        final int size = K.rows();
        final double alphaConst = kxt.dot(new DenseVector(alphaExpanded, 0, size));
        if (delta_t > errorTolerance)// add to the dictionary
        {
            vecs.add(x_t);
            if (kernelAccel != null)
                kernelAccel.addAll(qi);

            if (size == KExpanded.rows())// we need to grow first
            {
                KExpanded.changeSize(size * 2, size * 2);
                InvKExpanded.changeSize(size * 2, size * 2);
                PExpanded.changeSize(size * 2, size * 2);

                alphaExpanded = Arrays.copyOf(alphaExpanded, size * 2);
            }

            Matrix.OuterProductUpdate(InvK, alphas_t, alphas_t, 1 / delta_t);
            K = new SubMatrix(KExpanded, 0, 0, size + 1, size + 1);
            InvK = new SubMatrix(InvKExpanded, 0, 0, size + 1, size + 1);
            P = new SubMatrix(PExpanded, 0, 0, size + 1, size + 1);

            // update bottom row and side columns
            for (int i = 0; i < size; i++) {
                K.set(size, i, kxt.get(i));
                K.set(i, size, kxt.get(i));

                InvK.set(size, i, -alphas_t.get(i) / delta_t);
                InvK.set(i, size, -alphas_t.get(i) / delta_t);

                // P is zeros, no change
            }
            // update bottom right corner
            K.set(size, size, k_tt);
            InvK.set(size, size, 1 / delta_t);
            P.set(size, size, 1.0);

            for (int i = 0; i < size; i++)
                alphaExpanded[i] -= alphas_t.get(i) * (y_t - alphaConst) / delta_t;
            alphaExpanded[size] = (y_t - alphaConst) / delta_t;
        } else// project onto dictionary
        {
            Vec q_t = P.multiply(alphas_t);
            q_t.mutableDivide(1 + alphas_t.dot(q_t));

            Matrix.OuterProductUpdate(P, q_t, alphas_t.multiply(P), -1);

            Vec InvKqt = InvK.multiply(q_t);
            for (int i = 0; i < size; i++)
                alphaExpanded[i] += InvKqt.get(i) * (y_t - alphaConst);
        }
    }

}
